import asyncio
from abc import ABC, abstractmethod
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from enum import Enum
from functools import partial

import numpy as np

from rllm.agents.agent import BaseAgent, Episode, Trajectory
from rllm.engine.rollout.rollout_engine import RolloutEngine
from rllm.environments.base.base_env import BaseEnv


class TerminationReason(Enum):
    """Enumeration of the reasons a workflow episode can end.

    Every member's value is a lowercase string label equal to the member
    name in snake case.
    """

    # Length and turn limits (presumably raised by concrete workflows via
    # TerminationEvent -- confirm against the subclasses).
    MAX_PROMPT_LENGTH_EXCEEDED = "max_prompt_length_exceeded"
    MAX_RESPONSE_LENGTH_EXCEEDED = "max_response_length_exceeded"
    ENV_DONE = "env_done"
    MAX_TURNS_EXCEEDED = "max_turns_exceeded"

    # Assigned by Workflow.run_with_termination_handling when the run
    # exceeds the workflow timeout.
    TIMEOUT = "timeout"

    # Fallback used when no more specific reason is available.
    UNKNOWN = "unknown"

    # Assigned by Workflow.run_with_termination_handling when the run raises
    # an unexpected exception.
    ERROR = "error"


class TerminationEvent(Exception):
    """Exception used to stop a workflow run with a specific reason.

    The reason is kept on the ``reason`` attribute so that a handler can
    read it back after catching the exception.
    """

    def __init__(self, reason: TerminationReason = TerminationReason.UNKNOWN):
        """Store ``reason`` and build the exception message from it.

        Args:
            reason: Why the run is being terminated. Defaults to
                ``TerminationReason.UNKNOWN``.
        """
        self.reason = reason
        message = f"Terminated: {reason}"
        super().__init__(message)


class Workflow(ABC):
    """Abstract base class for a rollout workflow.

    Subclasses implement :meth:`run`. The base class provides termination,
    timeout and error handling around ``run``, bookkeeping for committed
    trajectories, default reward post-processing, and helpers that reset or
    inspect any agents (``BaseAgent``) and environments (``BaseEnv``) stored
    as attributes on the workflow instance.
    """

    # Identifier and task of the current run. They are populated by reset();
    # the class-level defaults keep postprocess_episode() usable (instead of
    # raising AttributeError) when reset() has not been called yet.
    uid: str | None = None
    task: dict | None = None

    def __init__(self, rollout_engine: RolloutEngine, executor: ThreadPoolExecutor, timeout=1e6, gamma=0.0, reward_bonus_coeff=0.0, **kwargs):
        """Initialize the Workflow.

        Args:
            rollout_engine: The rollout engine to use.
            executor: The executor used by ``run_in_executor``.
            timeout: The timeout for a single run; it is truncated to an
                integer and passed to ``asyncio.wait_for``.
            gamma: The discount factor used by ``adjust_step_rewards``.
            reward_bonus_coeff: The reward shaping coefficient used by
                ``adjust_step_rewards``.
            **kwargs: Additional keyword arguments. They are accepted and
                ignored by the base class.
        """
        self.rollout_engine = rollout_engine
        self.executor = executor
        self.timeout = int(timeout)
        self.gamma = gamma
        self.reward_bonus_coeff = reward_bonus_coeff

        # Deep copies of the trajectories handed to commit() since the last reset().
        self._completed_trajectories: list[Trajectory] = []

    @abstractmethod
    async def run(self, task: dict, uid: str, **kwargs) -> Episode | None:
        """Execute the workflow on a single task

        Args:
            task: The task to execute.
            uid: The unique identifier for the task.
            **kwargs: Additional keyword arguments.

        Returns:
            Episode: The episode generated by the workflow, or None, in which
            case ``run_with_termination_handling`` builds the episode from
            the collected trajectories.
        """

    async def run_with_termination_handling(self, task: dict, uid: str, **kwargs) -> Episode:
        """Wrapper method around workflow.run that handles termination events, errors, timeouts, and post-processing.

        Args:
            task: The task to execute.
            uid: The unique identifier for the task.
            **kwargs: Additional keyword arguments forwarded to ``run``.

        Returns:
            Episode: The episode returned by ``run`` if it returned one,
            otherwise an episode assembled from the collected trajectories
            and tagged with the matching termination reason.
        """
        try:
            output = await asyncio.wait_for(self.run(task, uid, **kwargs), timeout=self.timeout)
            if isinstance(output, Episode):
                return output  # we assume it's already postprocessed
            return self.postprocess_episode(self.collect_trajectories(), TerminationReason.UNKNOWN)
        except asyncio.TimeoutError:
            return self.postprocess_episode(self.collect_trajectories(), TerminationReason.TIMEOUT)
        except TerminationEvent as e:
            return self.postprocess_episode(self.collect_trajectories(), e.reason)
        except Exception as e:
            import traceback

            # Keep the failure details on the episode instead of propagating the exception.
            error_details = {"error_message": str(e), "error_type": type(e).__name__, "traceback": traceback.format_exc()}
            return self.postprocess_episode(self.collect_trajectories(), TerminationReason.ERROR, error=error_details)

    def commit(self, name: str | None = None, agent: BaseAgent | None = None, trajectory: Trajectory | None = None, reset: bool = False) -> None:
        """Commit a trajectory for training.

        Exactly one of ``agent`` and ``trajectory`` must be given. A deep copy
        of the trajectory is stored, and only if it has at least one step.

        Args:
            name: The name of the trajectory. When truthy it is written onto
                the trajectory itself before the copy is taken.
            agent: The agent that generated the trajectory.
            trajectory: The trajectory to commit.
            reset: Whether to reset the agent after committing (only applies
                when ``agent`` is given).
        """
        assert agent is not None or trajectory is not None, "Either agent or trajectory must be provided to workflow.commit"
        assert agent is None or trajectory is None, "Only one of agent or trajectory can be provided to workflow.commit"

        traj = agent.trajectory if agent is not None else trajectory
        if name:
            traj.name = name
        if traj.steps:
            self._completed_trajectories.append(deepcopy(traj))

        if agent is not None and reset:
            agent.reset()

    def collect_trajectories(self) -> Episode:
        """Collect the trajectories from the workflow

        Returns:
            Episode: A new episode holding the committed trajectories followed
            by a deep copy of the non-empty trajectory of every agent
            attribute whose trajectory uid has not been committed.
        """
        episode = Episode()

        # Start with completed trajectories
        episode.trajectories.extend(self._completed_trajectories)

        # Track completed trajectory uids
        completed_trajectory_uids = {trajectory.uid for trajectory in self._completed_trajectories}

        # Add trajectories from agents that aren't already in completed trajectories
        for _, attr_value in self._iter_attributes():
            if not isinstance(attr_value, BaseAgent) or not hasattr(attr_value, "trajectory"):
                continue
            if getattr(attr_value.trajectory, "uid", None) in completed_trajectory_uids:
                continue
            if len(attr_value.trajectory.steps) > 0:
                episode.trajectories.append(deepcopy(attr_value.trajectory))

        return episode

    def compute_trajectory_reward(self, trajectory: Trajectory) -> None:
        """
        Compute the trajectory-level reward.
        Default: sum the step rewards

        Args:
            trajectory: The trajectory to compute the reward for.
        """
        # Store a plain Python float rather than a numpy scalar so downstream
        # comparisons and serialization see a builtin type.
        trajectory.reward = float(np.sum([step.reward for step in trajectory.steps]))

    def adjust_step_rewards(self, trajectory: Trajectory) -> None:
        """
        Adjust the step-level rewards in place. Supports reward shaping and discounting.
        self.reward_bonus_coeff and self.gamma default to 0.0, so no adjustments are made by default.

        Args:
            trajectory: The trajectory to adjust the rewards for.
        """
        steps = trajectory.steps

        # reward shaping
        # s[i].reward = s[i].reward + bonus * (s[i].reward - s[i-1].reward) for i > 0
        if self.reward_bonus_coeff > 0.0:
            # Snapshot the rewards first so each bonus uses the unshaped values.
            raw_rewards = [step.reward for step in steps]
            for i in range(1, len(steps)):
                steps[i].reward += self.reward_bonus_coeff * (raw_rewards[i] - raw_rewards[i - 1])

        # Compute Monte Carlo returns (backward iteration)
        # G_t = R_{t+1} + γ * R_{t+2} + γ² * R_{t+3} + ... + γ^{T-t-1} * R_T
        if self.gamma > 0.0:
            G = 0.0
            for step in reversed(steps):
                G = step.reward + self.gamma * G
                step.reward = G  # Replace the reward with MC return

    def assign_episode_correctness(self, episode: Episode) -> None:
        """
        Assign an episode-level correctness flag.
        Default: True if the sum of the trajectory rewards is strictly positive.

        Args:
            episode: The episode to assign the correctness flag to.
        """
        total_reward = sum(trajectory.reward for trajectory in episode.trajectories)
        # bool() guarantees a builtin bool even when the rewards are numpy scalars.
        episode.is_correct = bool(total_reward > 0)

    def collect_metrics(self, episode: Episode) -> None:
        """
        Collect metrics from the episode.
        Default: for every trajectory name, the mean trajectory reward stored under "<name>_acc".

        Args:
            episode: The episode to collect metrics from.
        """
        rewards_by_name = defaultdict(list)
        for traj in episode.trajectories:
            rewards_by_name[traj.name].append(traj.reward)
        episode.metrics = {f"{k}_acc": float(np.mean(v)) for k, v in rewards_by_name.items()}

    def postprocess_episode(self, episode: Episode, termination_reason: TerminationReason | None = None, error: dict | None = None) -> Episode:
        """Collect and process the trajectories

        The trajectory-level reward is computed before the step rewards are
        adjusted, so it is based on the unadjusted step rewards.

        Args:
            episode: The episode to postprocess (modified in place).
            termination_reason: The termination reason for the episode.
                Falls back to ``TerminationReason.UNKNOWN`` when not given.
            error: The error details for the episode.

        Returns:
            Episode: The same episode object, after post-processing.
        """

        # 1. assign a task id and task
        episode.id = self.uid
        episode.task = self.task

        for trajectory in episode.trajectories:
            # depending on the termination reason, there may be a trajectory with an additional step with empty chat_completions
            # i.e., if it's thrown between agent.update_from_env() and agent.update_from_model()
            if trajectory.steps and not trajectory.steps[-1].chat_completions:
                trajectory.steps.pop()

            # 2. compute trajectory-level rewards
            self.compute_trajectory_reward(trajectory)

            # 3. adjust the step level rewards (e.g., reward shaping or discounting)
            if len(trajectory.steps) > 1:
                self.adjust_step_rewards(trajectory)

        # 4. assign an episode-level correctness flag
        self.assign_episode_correctness(episode)

        # 5. collect additional metrics workflow
        # by default, we report the acc of each agent using the traj reward
        self.collect_metrics(episode)

        # 6. store error details if provided
        if error is not None:
            episode.info["error"] = error

        # 7. assign a termination reason
        episode.termination_reason = termination_reason or TerminationReason.UNKNOWN

        return episode

    def reset(self, task: dict | None = None, uid: str | None = None) -> None:
        """Reset the workflow

        Args:
            task: The task to reset the workflow to.
            uid: The unique identifier for the task.
        """
        # set the uid and task
        self.uid = uid
        self.task = task
        self._completed_trajectories = []

        # reset agents (look for public attributes that are BaseAgent instances)
        for _, attr_value in self._iter_attributes():
            if isinstance(attr_value, BaseAgent) and hasattr(attr_value, "reset"):
                attr_value.reset()
                attr_value.trajectory.task = task

        # reset environments (look for public attributes that are BaseEnv instances)
        for _, attr_value in self._iter_attributes():
            if isinstance(attr_value, BaseEnv) and hasattr(attr_value, "reset"):
                attr_value.reset(task=task)

    def is_multithread_safe(self) -> bool:
        """Check if the workflow is multithread safe

        Returns:
            bool: False if any attribute (private ones included) is a BaseEnv
            that reports it is not multithread safe, True otherwise.
        """
        for _, attr_value in self._iter_attributes(include_private=True):
            if isinstance(attr_value, BaseEnv) and not attr_value.is_multithread_safe():
                return False
        return True

    async def run_in_executor(self, fn, *args, **kwargs):
        """Run a function in the workflow's separate thread pool executor.

        Args:
            fn: The function to run.
            *args: The arguments to pass to the function.
            **kwargs: The keyword arguments to pass to the function.

        Returns:
            The value returned by ``fn(*args, **kwargs)``.
        """
        # This is a coroutine, so a loop is always running here;
        # get_running_loop() is the recommended way to obtain it.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, partial(fn, *args, **kwargs))

    def _iter_attributes(self, include_private: bool = False):
        """Lazily yield ``(name, value)`` for the names listed by ``dir(self)``.

        Args:
            include_private: When False, names starting with an underscore
                are skipped without being looked up.
        """
        for attr_name in dir(self):
            if not include_private and attr_name.startswith("_"):
                continue
            yield attr_name, getattr(self, attr_name)
